{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This application demonstrates how to draw a simple neural network diagram using the Graph mark. \n",
    "Interactions can be enabled by adding event handlers (click, hover, etc.) on the nodes of the network. \n",
    "See the [Mark Interactions notebook](../Interactions/Mark%20Interactions.ipynb) and the [Scatter notebook](../Marks/Object%20Model/Scatter.ipynb) for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain, product\n",
    "\n",
    "import numpy as np\n",
    "from bqplot import LinearScale, Graph, Figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeuralNet(Figure):\n",
    "    \"\"\"Figure that draws a fully connected, layered network with a Graph mark.\n",
    "\n",
    "    Nodes are arranged in one column per layer and every node is linked to\n",
    "    every node of the next layer.\n",
    "\n",
    "    Keyword arguments handled here (anything else is passed on to Figure):\n",
    "\n",
    "    num_inputs : int, required\n",
    "        Number of input nodes.\n",
    "    num_hidden_layers : sequence of int, required\n",
    "        Size of each hidden layer, in order (a list of sizes, not a count).\n",
    "    num_outputs : int, required\n",
    "        Number of output nodes.\n",
    "    height, width : int\n",
    "        Figure size in pixels; default 600 and 960.\n",
    "    directed_links : bool\n",
    "        Value for the Graph mark's ``directed`` option; default False.\n",
    "    layer_colors : sequence of str\n",
    "        One color per layer, input layer first; default 'Orange' for all.\n",
    "        Entries beyond the number of layers are ignored.\n",
    "\n",
    "    Raises KeyError if a required argument is missing and ValueError if\n",
    "    ``layer_colors`` has fewer entries than there are layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        # pop() instead of get(): these options belong to NeuralNet, and\n",
    "        # traitlets deprecates handing unrecognized keyword arguments up to\n",
    "        # the base-class constructor.\n",
    "        self.height = kwargs.pop('height', 600)\n",
    "        self.width = kwargs.pop('width', 960)\n",
    "        self.directed_links = kwargs.pop('directed_links', False)\n",
    "\n",
    "        self.num_inputs = kwargs.pop('num_inputs')\n",
    "        self.num_hidden_layers = kwargs.pop('num_hidden_layers')\n",
    "        self.nodes_output_layer = kwargs.pop('num_outputs')\n",
    "        self.layer_colors = kwargs.pop(\n",
    "            'layer_colors', ['Orange'] * (len(self.num_hidden_layers) + 2))\n",
    "\n",
    "        # build_net() runs before the base constructor, so Figure options\n",
    "        # passed by the caller (e.g. title) are applied after its defaults.\n",
    "        self.build_net()\n",
    "        super(NeuralNet, self).__init__(**kwargs)\n",
    "\n",
    "    def build_net(self):\n",
    "        \"\"\"Create the Graph mark for the network and install it on the figure.\n",
    "\n",
    "        Sets layer_nodes, flattened_layer_nodes, link_matrix, nodes_x,\n",
    "        nodes_y, node_colors and graph, then the figure's marks, title and\n",
    "        layout size.\n",
    "        \"\"\"\n",
    "        # node labels per layer: x1.., h<layer>,<node>.., y1..\n",
    "        self.layer_nodes = [['x' + str(i + 1) for i in range(self.num_inputs)]]\n",
    "        for layer_no, size in enumerate(self.num_hidden_layers, start=1):\n",
    "            self.layer_nodes.append(\n",
    "                ['h' + str(layer_no) + ',' + str(j + 1) for j in range(size)])\n",
    "        self.layer_nodes.append(\n",
    "            ['y' + str(i + 1) for i in range(self.nodes_output_layer)])\n",
    "        self.flattened_layer_nodes = list(chain(*self.layer_nodes))\n",
    "\n",
    "        layer_sizes = [len(layer) for layer in self.layer_nodes]\n",
    "        n_layers = len(layer_sizes)\n",
    "        n_nodes = len(self.flattened_layer_nodes)\n",
    "\n",
    "        # fail early with a readable message instead of a broadcast error\n",
    "        # from np.repeat further down\n",
    "        if len(self.layer_colors) < n_layers:\n",
    "            raise ValueError(\n",
    "                'layer_colors needs one color per layer: expected %d, got %d'\n",
    "                % (n_layers, len(self.layer_colors)))\n",
    "\n",
    "        # link matrix: NaN everywhere except a 1 from each node to every node\n",
    "        # of the next layer. Nodes are numbered consecutively layer by layer,\n",
    "        # so layer k owns the index range [bounds[k], bounds[k + 1]).\n",
    "        bounds = np.concatenate(([0], np.cumsum(layer_sizes)))\n",
    "        self.link_matrix = np.full((n_nodes, n_nodes), np.nan)\n",
    "        for k in range(n_layers - 1):\n",
    "            self.link_matrix[bounds[k]:bounds[k + 1],\n",
    "                             bounds[k + 1]:bounds[k + 2]] = 1\n",
    "\n",
    "        # x: one evenly spaced column per layer, strictly inside (0, 100)\n",
    "        column_x = np.linspace(0, 100, n_layers + 1, endpoint=False)[1:]\n",
    "        self.nodes_x = np.repeat(column_x, layer_sizes)\n",
    "\n",
    "        # y: nodes of a layer evenly spaced inside (0, 100); the first node\n",
    "        # of each layer gets the largest y\n",
    "        self.nodes_y = np.concatenate(\n",
    "            [np.linspace(0, 100, size + 1, endpoint=False)[1:][::-1]\n",
    "             for size in layer_sizes])\n",
    "\n",
    "        # one color per node, taken from its layer's color\n",
    "        self.node_colors = np.repeat(np.array(self.layer_colors[:n_layers]),\n",
    "                                     layer_sizes).tolist()\n",
    "\n",
    "        x_scale = LinearScale(min=0, max=100)\n",
    "        y_scale = LinearScale(min=0, max=100)\n",
    "\n",
    "        self.graph = Graph(node_data=[{'label': label,\n",
    "                                       'label_display': 'none'}\n",
    "                                      for label in self.flattened_layer_nodes],\n",
    "                           link_matrix=self.link_matrix,\n",
    "                           link_type='line',\n",
    "                           colors=self.node_colors,\n",
    "                           directed=self.directed_links,\n",
    "                           scales={'x': x_scale, 'y': y_scale},\n",
    "                           x=self.nodes_x,\n",
    "                           y=self.nodes_y)\n",
    "        # NOTE(review): 'stroke' normally takes a color; '1.5' looks like it\n",
    "        # was meant for 'stroke-width' - confirm before changing.\n",
    "        self.graph.hovered_style = {'stroke': '1.5'}\n",
    "        self.graph.unhovered_style = {'opacity': '0.4'}\n",
    "        self.graph.selected_style = {'opacity': '1',\n",
    "                                     'stroke': 'red',\n",
    "                                     'stroke-width': '2.5'}\n",
    "\n",
    "        self.marks = [self.graph]\n",
    "        self.title = 'Neural Network'\n",
    "        self.layout.width = str(self.width) + 'px'\n",
    "        self.layout.height = str(self.height) + 'px'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b2d75468c39d4bcf896561f72febb74d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "NeuralNet(fig_margin={'top': 60, 'bottom': 60, 'left': 60, 'right': 60}, layout=Layout(height='600px', width='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 3 inputs, four hidden layers of 10, 10, 8 and 5 nodes, 1 output\n",
    "NeuralNet(\n",
    "    num_inputs=3,\n",
    "    num_hidden_layers=[10, 10, 8, 5],\n",
    "    num_outputs=1,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "input_collapsed": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
